img_names = os.listdir(data_path)
source = img_names
alphabet = ''.join(source)
# 讀取圖檔,並轉換大小為80*80,以及轉換成RGB
def img_loader(img_path):
image = Image.open(img_path)
img = image.resize((80, 80),Image.ANTIALIAS) #resize image with high-quality
return img.convert('RGB')
# 將圖檔與label對應,丟入自定義的資料集內
def make_dataset(data_path, alphabet, num_class):
samples = []
for i in os.listdir(data_path):
for j in os.listdir(data_path + '/' + i):
img_path = data_path + '/' + i + '/' + j
target_str = j.split('.')[0][-1]
vec = [0] * 800
vec[alphabet.find(target_str)] = 1
target = vec
samples.append((img_path, target))
return samples
例如這個字是"不",由alphabet的位置可以看到alphabet[3]的位置是"不",故在alphabet[3]的位置為1,代表他的label,其餘位置皆為0。
torch.utils.data.Dataset,是一個自定義資料集的框架。
__ init __()
def __init__(self, data_path, num_class=800,transform=None,target_transform=None, alphabet=alphabet):
super(Dataset, self).__init__()
self.data_path = data_path
self.num_class = num_class
self.transform = transform
self.target_transform = target_transform
self.alphabet = alphabet
self.samples = make_data.set(self.data_path, self.alphabet)
__ len __ ()
def __len__(self):
return len(self.samples)
__ getitem __ ()
def __getitem__(self, index):
img_path, target = self.samples[index]
img = img_loader(img_path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, torch.Tensor(target) # 在torch裡面,array都要轉成Tensor型式
完整程式碼
class CaptchaData(Dataset):
def __init__(self, data_path, num_class=800,
transform=None, target_transform=None, alphabet=alphabet):
super(Dataset, self).__init__()
self.data_path = data_path
self.num_class = num_class
self.transform = transform
self.target_transform = target_transform
self.alphabet = alphabet
self.samples = make_dataset(self.data_path, self.alphabet)
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
img_path, target = self.samples[index]
img = img_loader(img_path)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, torch.Tensor(target)
torch.utils.data.DataLoader
Dataset設置好後,DataLoader可以依照batch_size讓我們取樣,非常方便。
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader
transforms = Compose([ToTensor()])
train_dataset = CaptchaData(r'C:\Users\Frank\PycharmProjects\practice\mountain\data_final_20210530\official_in_800',transform=transforms)
train_data_loader = DataLoader(train_dataset, batch_size=1, num_workers=0,
shuffle=True, drop_last=True)
for (data,label) in train_data_loader:
print((data,label))
我把batch_size設定為1,他一次就只取出一組圖片樣本及標籤。
除了自定義資料集以外,還有可以torchvision.datasets.ImageFolder
来處理資料集,用法會在於你分好類別,他的資料夾名稱就是他的label,而裡面圖片都屬於這個label。
深度學習有很多很好玩的地方,但也有很多的坑,debug我都要找很久XDD,重點是東西太多,絕對學不完,而且很吃硬體設備。有時候會覺得自己好笨,都學不會,但看久了發現懂一點了,就又有動力繼續往下學了,接觸深度學習的朋友們,我們一起繼續努力吧!
小弟我是試著用自定義資料集來處理,原因只想練習以及可以更彈性的操作載入資料的動作。
前面加載圖片時我們把transforms設置為None,現在我們丟模型訓練要對圖片做transforms,他可以增加圖片的多樣性,例如:旋轉、平移、變形等等,明天來跟大家分享torchvision很好用的套件transforms。